-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR] feat(mlir-tblgen): Add support for dialect interfaces #170046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-ods Author: AidinT (aidint) ChangesCurrently, Dialect Interfaces can't be defined in ODS. This PR adds the support for dialect interfaces. It follows the same approach with other interfaces and extends on top of Given the following input: It will generate the following code: /*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Dialect Interface Declarations *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/
namespace mlir {
/// Define a base inlining interface class to allow for dialects to opt-in to the inliner.
class DialectInlinerInterface : public ::mlir::DialectInterface::Base<DialectInlinerInterface> {
public:
/// Returns true if the given region 'src' can be inlined into the region
/// 'dest' that is attached to an operation registered to the current dialect.
/// 'valueMapping' contains any remapped values from within the 'src' region.
/// This can be used to examine what values will replace entry arguments into
/// the 'src' region, for example.
virtual bool isLegalToInline(::mlir::Region * dest, ::mlir::Region * src, ::mlir::IRMapping & valueMapping) const;
protected:
DialectInlinerInterface(::mlir::Dialect *dialect) : Base(dialect) {}
};
} // namespace mlir
bool ::mlir::DialectInlinerInterface::isLegalToInline(::mlir::Region * dest, ::mlir::Region * src, ::mlir::IRMapping & valueMapping) const {
return true;
}Full diff: https://github.com/llvm/llvm-project/pull/170046.diff 6 Files Affected:
diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 0cbe3fa25c9e7..746ad3408f424 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
!if(!empty(cppNamespace),"", cppNamespace # "::") # name
>;
+// DialectInterface represents an interface registered to an operation.
+class DialectInterface<string name, list<Interface> baseInterfaces = []>
+ : Interface<name, baseInterfaces>, OpInterfaceTrait<name>;
+
+
// Whether to declare the interface methods in the user entity's header. This
// class simply wraps an Interface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 7c36cbc1192ac..f62d21da467a1 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -157,6 +157,13 @@ struct TypeInterface : public Interface {
static bool classof(const Interface *interface);
};
+// An interface that is registered to a Dialect.
+struct DialectInterface : public Interface {
+ using Interface::Interface;
+
+ static bool classof(const Interface *interface);
+};
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index b0ad3ee59a089..77a6cecebbeaf 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}
+
+//===----------------------------------------------------------------------===//
+// DialectInterface
+//===----------------------------------------------------------------------===//
+
+bool DialectInterface::classof(const Interface *interface) {
+ return interface->getDef().isSubClassOf("DialectInterface");
+}
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
new file mode 100644
index 0000000000000..9b424bf501be3
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -0,0 +1,66 @@
+// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+
+include "mlir/IR/Interfaces.td"
+
+def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
+ let description = [{
+ This is an example dialect interface without default method body.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
+// DECL: virtual bool isExampleDialect() const = 0;
+// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
+// DECL: protected:
+// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
+
+def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
+ let description = [{
+ This is an example dialect interface with default method bodies.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return true;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: virtual bool isExampleDialect() const;
+// DECL: bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
+// DECL-NEXT: return true;
+// DECL-NEXT: }
+
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e0576c8..d7087cba3c874 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
AttrOrTypeFormatGen.cpp
BytecodeDialectGen.cpp
DialectGen.cpp
+ DialectInterfacesGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
EnumPythonBindingGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
new file mode 100644
index 0000000000000..2fc500343501c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -0,0 +1,176 @@
+//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectInterfaceGen generates definitions for Dialect interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CppGenUtilities.h"
+#include "DocGenUtilities.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using llvm::Record;
+using llvm::RecordKeeper;
+using mlir::tblgen::Interface;
+using mlir::tblgen::InterfaceMethod;
+
+/// Emit a string corresponding to a C++ type, followed by a space if necessary.
+static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
+ type = type.trim();
+ os << type;
+ if (type.back() != '&' && type.back() != '*')
+ os << " ";
+ return os;
+}
+
+/// Emit the method name and argument list for the given method.
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
+ raw_ostream &os) {
+ os << name << '(';
+ llvm::interleaveComma(method.getArguments(), os,
+ [&](const InterfaceMethod::Argument &arg) {
+ os << arg.type << " " << arg.name;
+ });
+ os << ") const";
+}
+
+/// Get an array of all Dialect Interface definitions
+static std::vector<const Record *>
+getAllInterfaceDefinitions(const RecordKeeper &records) {
+ std::vector<const Record *> defs =
+ records.getAllDerivedDefinitions("DialectInterface");
+
+ llvm::erase_if(defs, [&](const Record *def) {
+ // Ignore interfaces defined outside of the top-level file.
+ return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
+ llvm::SrcMgr.getMainFileID();
+ });
+ return defs;
+}
+
+namespace {
+/// This struct is the generator used when processing tablegen dialect
+/// interfaces.
+class DialectInterfaceGenerator {
+public:
+ DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
+ : defs(getAllInterfaceDefinitions(records)), os(os) {}
+
+ bool emitInterfaceDecls();
+
+protected:
+ void emitInterfaceDecl(const Interface &interface);
+ void emitInterfaceMethodsDef(const Interface &interface);
+
+ /// The set of interface records to emit.
+ std::vector<const Record *> defs;
+ // The stream to emit to.
+ raw_ostream &os;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface declarations
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceMethodDoc(const InterfaceMethod &method,
+ raw_ostream &os, StringRef prefix = "") {
+ if (std::optional<StringRef> description = method.getDescription())
+ tblgen::emitDescriptionComment(*description, os, prefix);
+}
+
+static void emitInterfaceDeclMethods(const Interface &interface,
+ raw_ostream &os) {
+ for (auto &method : interface.getMethods()) {
+ emitInterfaceMethodDoc(method, os, " ");
+ os << " virtual ";
+ emitCPPType(method.getReturnType(), os);
+ emitMethodNameAndArgs(method, method.getName(), os);
+ if (!method.getBody())
+ // no default method body
+ os << " = 0";
+ os << ";\n";
+ }
+}
+
+void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
+
+ StringRef interfaceName = interface.getName();
+
+ tblgen::emitSummaryAndDescComments(os, "",
+ interface.getDescription().value_or(""));
+
+ // Emit the main interface class declaration.
+ os << llvm::formatv(
+ "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
+ "public:\n",
+ interfaceName);
+
+ emitInterfaceDeclMethods(interface, os);
+ os << llvm::formatv("\nprotected:\n"
+ " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+ interfaceName);
+
+ os << "};\n";
+}
+
+void DialectInterfaceGenerator::emitInterfaceMethodsDef(
+ const Interface &interface) {
+
+ for (auto &method : interface.getMethods()) {
+ if (auto body = method.getBody()) {
+ emitCPPType(method.getReturnType(), os);
+ os << interface.getCppNamespace() << "::";
+ os << interface.getName() << "::";
+ emitMethodNameAndArgs(method, method.getName(), os);
+ os << " {\n " << body.value() << "\n}\n";
+ }
+ }
+}
+
+bool DialectInterfaceGenerator::emitInterfaceDecls() {
+
+ llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
+
+ // Sort according to ID, so defs are emitted in the order in which they appear
+ // in the Tablegen file.
+ std::vector<const Record *> sortedDefs(defs);
+ llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
+ return lhs->getID() < rhs->getID();
+ });
+
+ for (const Record *def : sortedDefs)
+ emitInterfaceDecl(Interface(def));
+
+ os << "\n";
+ for (const Record *def : sortedDefs)
+ emitInterfaceMethodsDef(Interface(def));
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genDecls(
+ "gen-dialect-interface-decls",
+ "Generate dialect interface declarations.",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
+ });
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you convert an interface like DialectInlinerInterface for example here so we have a concrete working example as well.
I converted |
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks :)
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually I suspect we need to update the documentation as well, can you do this?
813b84b to
3f0fc8b
Compare
|
@joker-eph I updated the documentation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: newline above (certain markdown renders require separation between text and codeblock)
mlir/docs/Interfaces.md
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: the spacing here is different from the other interfaces, lets keep to something like how mlir/include/mlir/Interfaces/ControlFlowInterfaces.td does it.
mlir/docs/Interfaces.md
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With that spacing change I think this fits on previous line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Newline above. You could perhaps even point to example in code / cmake usage.
mlir/docs/Interfaces.md
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't this be --gen-dialect-interface-decls?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space before (
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No change needed, but you could also use indented ostream here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used indented ostream. Can you please review again to see if that's the correct usage?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could have sworn I had this exposed somewhere :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I grabbed this one from OpInterfacesGen.cpp. I haven't done any research if this has been implemented somewhere else. Do you remember where it might be implemented?
Currently, Dialect Interfaces can't be defined in ODS. This PR adds the support for dialect interfaces. It follows the same approach with other interfaces and extends on top of
Interfaceclass defined inmlir/TableGen/Interfaces.h.Given the following input:
It will generate the following code: